package com.tsfyun.scm.config;

import com.tsfyun.common.base.dto.Result;
import com.tsfyun.common.base.enums.ResultCodeEnum;
import com.tsfyun.common.base.support.HttpServletRequestWrapper;
import com.tsfyun.common.base.util.ResultUtil;
import com.tsfyun.common.base.util.SQLFilter;
import com.tsfyun.common.base.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Locale;
import java.util.Objects;

/**
 * SQL注入拦截器
 */
@Component
@Slf4j
public class SqlInterceptor extends HandlerInterceptorAdapter {

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        boolean noInject = true;
        String method = request.getMethod();
        String contentType = StringUtils.null2EmptyWithTrim(request.getContentType());
        HttpServletRequestWrapper sqlRequest;
        if (request instanceof HttpServletRequest) {
            sqlRequest = new HttpServletRequestWrapper(request);
        } else {
            return true;
        }
        try {
            //JSON请求
            if ("POST".equalsIgnoreCase(method) && contentType.contains("application/json")) {
                //此处需要注意流只能读一次
                HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request);
                String requestBody = wrapper.getRequestBodyParame();
                if (SQLFilter.sqlInject(requestBody)) {
                    noInject = false;
                }
            } else {
                Enumeration<String> paramNames = sqlRequest.getParameterNames();
                while (paramNames.hasMoreElements()) {
                    String name = paramNames.nextElement();
                    String[] values = request.getParameterValues(name);
                    if (Objects.nonNull(values) && values.length > 0) {
                        long injectCnt = Arrays.asList(values).stream().filter(r -> SQLFilter.sqlInject(r.toLowerCase())).count();
                        if (injectCnt > 0) {
                            noInject = false;
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.error("校验SQL注入异常",e);
            noInject = true;
        }
        if(!noInject) {
            Result result = Result.error(ResultCodeEnum.FAIL.getCode(),"填写数据包含非法字符");
            ResultUtil.errorBack(response,result);
        }
        return noInject;
    }



}
